-
Notifications
You must be signed in to change notification settings - Fork 123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for [[clad::non_differentiable]]
in reverse mode
#916
Add support for [[clad::non_differentiable]]
in reverse mode
#916
Conversation
clang-tidy review says "All clean, LGTM! 👍" |
If we mark that operator overload as non-differentiable, then should the issue still happen? |
I just tried changing the test and it still errors, but this time it is |
Oh, yes, you are right. Thank you for the details. |
e3fe735
to
22e5c11
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clang-tidy made some suggestions
.get(); | ||
// Creating a zero derivative | ||
auto* zero = | ||
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
warning: argument comment missing for literal argument 'val' [bugprone-argument-comment]
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); | |
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, /*val=*/0); |
@@ -2867,6 +2913,10 @@ | |||
"CXXMethodDecl nodes not supported yet!"); | |||
MemberExpr* clonedME = utils::BuildMemberExpr( | |||
m_Sema, getCurrentScope(), baseDiff.getExpr(), field->getName()); | |||
auto zero = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
warning: 'auto zero' can be declared as 'auto *zero' [llvm-qualified-auto]
auto zero = | |
auto *zero = |
@@ -2867,6 +2913,10 @@ | |||
"CXXMethodDecl nodes not supported yet!"); | |||
MemberExpr* clonedME = utils::BuildMemberExpr( | |||
m_Sema, getCurrentScope(), baseDiff.getExpr(), field->getName()); | |||
auto zero = | |||
ConstantFolder::synthesizeLiteral(m_Context.DoubleTy, m_Context, 0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
warning: argument comment missing for literal argument 'val' [bugprone-argument-comment]
ConstantFolder::synthesizeLiteral(m_Context.DoubleTy, m_Context, 0); | |
ConstantFolder::synthesizeLiteral(m_Context.DoubleTy, m_Context, /*val=*/0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clang-tidy made some suggestions
if (condVarResult.getDecl_dx()) | ||
addToCurrentBlock(BuildDeclStmt(condVarResult.getDecl_dx())); | ||
auto condInit = condVarClone->getInit(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
warning: 'auto condInit' can be declared as 'auto *condInit' [llvm-qualified-auto]
auto condInit = condVarClone->getInit(); | |
auto *condInit = condVarClone->getInit(); |
@MihailMihov, can you rebase this pull request? |
530a10d
to
34c325b
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you take care of the clang-tidy suggestions and adapt the new test in such a way that we work around the pre-existing issue of #917?
aae40ae
to
105982c
Compare
clang-tidy review says "All clean, LGTM! 👍" |
105982c
to
7b63512
Compare
clang-tidy review says "All clean, LGTM! 👍" |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #916 +/- ##
==========================================
+ Coverage 93.92% 93.94% +0.01%
==========================================
Files 55 55
Lines 8038 8061 +23
==========================================
+ Hits 7550 7573 +23
Misses 488 488
|
7b63512
to
7477596
Compare
clang-tidy review says "All clean, LGTM! 👍" |
@@ -2954,6 +2977,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |||
"CXXMethodDecl nodes not supported yet!"); | |||
MemberExpr* clonedME = utils::BuildMemberExpr( | |||
m_Sema, getCurrentScope(), baseDiff.getExpr(), field->getName()); | |||
auto* zero = ConstantFolder::synthesizeLiteral(m_Context.DoubleTy, | |||
m_Context, /*val=*/0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
An IntTy
would be more suitable here because we might need to zero-initialize pointers.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also saw that this is 'DoubleTy` in forward mode. Can you also test that and fix that in a separate PR, if possible?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did change the DoubleTy
to an IntTy
, but that didn't actually do anything to the tests and pointers didn't work either way. I added another check for them when visiting UO_Deref
, not sure if there isn't a better way to fix them however, but now the new test is passing.
SimpleFunctions1() noexcept : x(0), y(0) {} | ||
SimpleFunctions1(double p_x, double p_y) noexcept : x(p_x), y(p_y) {} | ||
double x; | ||
non_differentiable double y; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's also test with some pointer member types.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just added an fn_s1_field_pointer
to the test, is that what you wanted me to test?
f731f1d
to
dca4fad
Compare
clang-tidy review says "All clean, LGTM! 👍" |
// Calling the function without computing derivatives | ||
llvm::SmallVector<Expr*, 4> ClonedArgs; | ||
for (unsigned i = 0, e = CE->getNumArgs(); i < e; ++i) | ||
ClonedArgs.push_back(Clone(CE->getArg(i))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simply cloning the argument seems incorrect. What if the arguments have side-effect which can affect the derivative computation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure that I understand the issue here. It the arguments do have side effects then those would be kept when we clone them, is that not what is expected? When do you think that this wouldn't work correctly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider an example such as this:
some_non_differentiable_fn_call(r = u * v, s = u + v);
Now, if we simply clone the arguments then we will not generate adjoint statements for r = u * v
and s = u + v
.
You don't necessarily need to fix this issue in this PR.
|
||
// If we have a pointer to a member expression, which is | ||
// non-differentiable, we just return a clone of the original expression. | ||
if (auto* ME = dyn_cast<MemberExpr>(diff.getExpr())) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can it be handled more uniformly in VisitMemberExpr
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If diff.getExpr_dx()
is 0
, then we would not need to add a special condition here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can it be handled more uniformly in
VisitMemberExpr
?
It is already handled in VisitMemberExpr
where we return {clonedME, zero}
, but what happened without the above check is that it tries to build something along the lines of *0 += ...
, when visiting the UO_Deref
.
If
diff.getExpr_dx()
is0
, then we would need to add a special condition here.
With the above check I eliminate one of the cases where we could end up with a 0 above, if you can think of anything else, then we should handle those too. Do you have anything in mind?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if you can think of anything else, then we should handle those too.
I am concerned there might be many such cases... For example, ->
operator.
It might be better to test if the diff.getExpr_dx()
is a constant (or 0
) instead of testing if the member has a non-differentiable attribute. This is because it will help us cover more cases. For example, the adjoint of member expressions of global class objects should also be 0
and consequently they should be handled similarly but they do not have non_differentiable
attribute.
@@ -2954,6 +2984,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |||
"CXXMethodDecl nodes not supported yet!"); | |||
MemberExpr* clonedME = utils::BuildMemberExpr( | |||
m_Sema, getCurrentScope(), baseDiff.getExpr(), field->getName()); | |||
auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is better to create zero
inside the if-condition as it is only used there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have changed that, also what do you think about this being IntTy
or DoubleTy
. In the forward mode code it was DoubleTy
, but Vaibhav suggested changing it to IntTy
. It didn't seem to make any difference, but maybe somewhere it will?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it will make a difference anywhere. Clang automatically adds cast nodes to convert 0
to the right type.
// of lambdas is happening in the `VisitCallExpr`. For now, only the | ||
// declarations with lambda expressions without captures are supported. | ||
isLambda = typeDecl && typeDecl->isLambda(); | ||
if (isLambda || |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please add a test for a local variable declaration with non_differentiable attribute?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just added fn_non_diff_var
to Gradient/NonDifferentiable.C test, but I believe it's not working as expected. The correct output would be 0.00 0.00 right? I'll try to get that fixed now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please create an issue for this and resolve it in a follow-up pull-request?
84d2de1
to
635e36d
Compare
clang-tidy review says "All clean, LGTM! 👍" |
0083191
to
cd177a4
Compare
clang-tidy review says "All clean, LGTM! 👍" |
cd177a4
to
f4780e9
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good as an initial support for [[non_differentiable]]
. Thank you for working on this.
Can you please open issues for the comments which need some future work:
- Differentiating argument expressions of non differentiable function calls. This is required because argument expressions can have side-affects.
- Supporting non differentiate attribute on local variables.
- Improve handling of non_differentiable variables in expressions such as
*
(dereference operator),->
and so on.
Can you squash the tests into the other commit? |
clang-tidy review says "All clean, LGTM! 👍" |
f4780e9
to
bd386d6
Compare
clang-tidy review says "All clean, LGTM! 👍" |
I opened issues for 1 and 3. For 2 this PR includes a basic test and fix, but more work may be necessary. Do I just create an issue saying that there might be a cleaner fix or something more specific? |
bd386d6
to
f7c65a8
Compare
clang-tidy review says "All clean, LGTM! 👍" |
Yes, we should create an issue describing what this "more work may be necessary" means in technical and practical terms. |
f7c65a8
to
b590586
Compare
clang-tidy review says "All clean, LGTM! 👍" |
fixes #717
I've added the code for handing
[[clad::non_differentiable]]
that is already present in the forward mode visitor to the one for reverse mode. I also modified the tests from forward mode, butReverseMode/NonDifferentiable.C
is currently failing because of an issue with differentiating operator overloads in reverse mode.